---
title: Data augmentation in computer vision
keywords: fastai
sidebar: home_sidebar
summary: "Transforms to apply data augmentation in Computer Vision"
---
# Test image used throughout this page, resized to 600x400 for the demos.
img = PILImage(PILImage.create(TEST_IMAGE).resize((600,400)))
As with any `Transform`, you can pass `encodes` and `decodes` at init, or subclass and implement them. You can do the same for the `randomize` method, which is called at each `__call__`. Note that to have a consistent state for inputs and targets, a `RandTransform` must be applied at the tuple level.
By default the `randomize` behavior is to execute the transform with probability `p` (if subclassing and wanting to tweak that behavior, the attribute `self.do`, if it exists, is looked for to decide if the transform is executed or not).
Note: A `RandTransform` is only applied to the training set by default, so you have to pass `filt=0` if you are calling it directly and not through a `DataSource`. That behavior can be changed by setting the class variable `filt` of the transform to `None`.
def _add1(x): return x+1
# With p=0.5 the transform is applied only on some calls; `dumb_tfm.do`
# records whether the transform was executed on the last call.
dumb_tfm = RandTransform(_add1, p=0.5)
start = 2
for _ in range(10):
    t = dumb_tfm(start, filt=0)
    if dumb_tfm.do: test_eq(t, start+1)
    else: test_eq(t, start)
PIL transforms need to run before `ImageToByteTensor` (which has an order of 10).
# Show the original image next to its flip (p=1. forces the transform).
flip = PILFlip(p=1.)
_,axs = plt.subplots(1,2, figsize=(10,4))
show_image(img, ctx=axs[0], title='original')
show_image(flip(img, filt=0), ctx=axs[1], title='flipped');
# On a 3x3 image with two lit pixels, check the flipped pixel positions.
t = _pnt2tensor([[1,0], [2,1]], (3,3))
x = _tensor2pil(t)
y = flip(x, filt=0)
test_eq(tensor(array(y)), _pnt2tensor([[1,0], [0,1]], (3,3)))
# Points and bounding boxes are flipped consistently with the pixels
# (coordinates here appear shifted by -1, presumably the -1..1 scale — verify).
pnts = TensorPoint(tensor([[1.,0], [2,1]]) -1)
test_eq(flip(pnts, filt=0), tensor([[1.,0], [0,1]]) -1)
bbox = TensorBBox((tensor([[1.,0., 2.,1]]) -1, ["nothing"]))
test_eq(flip(bbox, filt=0)[0], tensor([[1.,0., 0.,1]]) -1)
By default each of the 8 dihedral transformations (including noop) have the same probability of being picked when the transform is applied. You can customize this behavior by passing your own draw function. To force a specific flip, you can also pass an integer between 0 and 7.
# Show all 8 dihedral transformations by forcing `draw=i` for i in 0..7.
_,axs = plt.subplots(2,4, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    dih = PILDihedral(p=1., draw=i)
    show_image(dih(img, filt=0), ctx=ax, title=f'Flip {i}')
# Check point transforms agree with pixel transforms for every dihedral flip.
t = _pnt2tensor([[1,0], [2,1]], (3,3))
x = _tensor2pil(t)
for i in range(8):
    dih = PILDihedral(p=1., draw=i)
    y = dih(x, filt=0)
    res = tensor(array(y))
    pnts = TensorPoint(tensor([[1.,0.], [2,1]]) -1)
    a,b = dih(pnts, filt=0), res.nonzero().flip(1).float() -1
    # `nonzero()` may return the two pixels in either order, hence the flip(0) case.
    assert equals(a,b) or equals(a,b.flip(0))
# `clip_remove_empty` clips box coordinates to the valid area; the label of
# the box that becomes empty is replaced by 0 (presumably background — verify).
bb = torch.tensor([[-2,-0.5,0.5,1.5], [-0.5,-0.5,0.5,0.5], [1,0.5,0.5,0.75]])
bb,lbl = clip_remove_empty(bb, ['too big', 'normal', 'empty'])
test_eq(bb, torch.tensor([[-1,-0.5,0.5,1.], [-0.5,-0.5,0.5,0.5], [1,0.5,0.5,0.75]]))
test_eq(lbl, ['too big', 'normal', 0])
# `CropPad` with several target sizes (square when given an int).
_,axs = plt.subplots(1,3,figsize=(12,4))
for ax,sz in zip(axs.flatten(), [300, 500, 700]):
    pad = CropPad(sz)
    show_image(pad(img), ctx=ax, title=f'Size {sz}');
# The three padding modes: zeros, border replication, reflection.
_,axs = plt.subplots(1,3,figsize=(12,4))
for ax,mode in zip(axs.flatten(), [PadMode.Zeros, PadMode.Border, PadMode.Reflection]):
    pad = CropPad((600,700), pad_mode=mode)
    show_image(pad(img), ctx=ax, title=mode);
On the validation set, we take a center crop.
# `ResizeMethod` members compare equal to their plain-string values.
test_eq(ResizeMethod.Squish, 'squish')
`size` can be an integer (in which case images will be resized to a square) or a tuple. Depending on the `method`, the image is squished, padded (with `pad_mode`) or cropped to `size`. When doing the resize, we use `resamples[0]` for images and `resamples[1]` for segmentation masks.
# Validation-set behavior (filt=1) for each resize method.
_,axs = plt.subplots(1,3,figsize=(12,4))
for ax,method in zip(axs.flatten(), [ResizeMethod.Squish, ResizeMethod.Pad, ResizeMethod.Crop]):
    rsz = Resize(256, method=method)
    show_image(rsz(img, filt=1), ctx=ax, title=method);
# Training-set behavior (filt=0) for the same methods.
_,axs = plt.subplots(1,3,figsize=(12,4))
for ax,method in zip(axs.flatten(), [ResizeMethod.Squish, ResizeMethod.Pad, ResizeMethod.Crop]):
    rsz = Resize(256, method=method)
    show_image(rsz(img, filt=0), ctx=ax, title=method);
#TODO test
The crop is picked with a random scale in the range `(min_scale, 1)` and a ratio in the range passed, then the resize is done with `resamples[0]` for images and `resamples[1]` for segmentation masks. On the validation set, we center crop the image if its ratio isn't in the range (to the minimum or maximum value), then resize.
# Apply the same `RandomResizedCrop` nine times to see the random crops drawn.
crop = RandomResizedCrop(256)
_,axs = plt.subplots(3,3,figsize=(9,9))
for ax in axs.flatten():
    cropped = crop(img)
    show_image(cropped, ctx=ax);
#TODO: test
def _batch_ex(bs):
    "Return a batch of `bs` copies of the test image `img` as a `TensorImage`."
    base = tensor(array(img)).permute(2,0,1).float()/255.
    batched = base[None].expand(bs, *base.shape)
    return TensorImage(batched)
Multiply all the matrices returned by `aff_fs` before doing the corresponding affine transformation on a basic grid corresponding to `size`, then apply all `coord_fs` on the resulting flow of coordinates before finally doing an interpolation with `mode` and `pad_mode`.
# `flip_mat` draws a random flip per batch element: the (0,0) entry of each
# affine matrix is either -1 (flipped) or 1 (not flipped).
x = flip_mat(torch.randn(100,4,3))
test_eq(set(x[:,0,0].numpy()), {-1,1}) #might fail with probability 2*2**(-100) (picked only 1s or -1s)
# Batch-level `Flip` applied to images, points and bounding boxes.
flip = Flip(p=1.)
t = _pnt2tensor([[1,0], [2,1]], (3,3))
y = flip(TensorImage(t[None,None]), filt=0)
test_eq(y, _pnt2tensor([[1,0], [0,1]], (3,3))[None,None])
pnts = TensorPoint((tensor([[1.,0.], [2,1]]) -1)[None])
test_eq(flip(pnts, filt=0), tensor([[[1.,0.], [0,1]]]) -1)
bbox = TensorBBox(((tensor([[1.,0., 2.,1]]) -1)[None], tensor([0.])[None]))
test_eq(flip(bbox, filt=0)[0], tensor([[[0.,0., 1.,1.]]]) -1)
# `_draw_mask` draws one value per batch element, either with the default
# callable `def_draw` or with the explicit draw argument (scalar or list);
# with p=1 every element gets the drawn value.
x = torch.zeros(5,2,3)
def_draw = lambda: random.randint(0,7)
t = _draw_mask(x, def_draw)
assert (0. <= t).all() and (t <= 7).all()
t = _draw_mask(x, def_draw, 1)
assert (0. <= t).all() and (t <= 1).all()
test_eq(_draw_mask(x, def_draw, 1, p=1), tensor([1.,1,1,1,1]))
test_eq(_draw_mask(x, def_draw, [0,1,2,3,4], p=1), tensor([0.,1,2,3,4]))
draw can be specified if you want to customize which flip is picked when the transform is applied (default is a random number between 0 and 7). It can be an integer between 0 and 7, a list of such integers (which then should have a length equal to the size of the batch) or a callable that returns an integer between 0 and 7.
# Apply each of the 8 dihedral flips to one element of a batch of 8.
t = _batch_ex(8)
dih = Dihedral(p=1., draw=list(range(8)))
y = dih(t, filt=0)
_,axs = plt.subplots(2,4, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'Flip {i}')
`draw` can be specified if you want to customize which angle is picked when the transform is applied (default is a random float between `-max_deg` and `max_deg`). It can be a float, a list of floats (which then should have a length equal to the size of the batch) or a callable that returns a float.
# Rotate each element of a batch of 5 by a different fixed angle.
thetas = [-30,-15,0,15,30]
rot = Rotate(draw=thetas, p=1.)
y = rot(_batch_ex(5), filt=0)
_,axs = plt.subplots(1,5, figsize=(15,3))
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'{thetas[i]} degrees')
`draw`, `draw_x` and `draw_y` can be specified if you want to customize which scale and center are picked when the transform is applied (default is a random float between 1 and `max_zoom` for the first, between 0 and 1 for the last two). Each can be a float, a list of floats (which then should have a length equal to the size of the batch) or a callable that returns a float.
draw_x and draw_y are expected to be the position of the center in pct, 0 meaning the most left/top possible and 1 meaning the most right/bottom possible.
# Zoom with fixed center (0.5, 0.5) and a different scale per batch element.
scales = [1., 1.1, 1.25, 1.5]
zoom = Zoom(draw=scales, p=1., draw_x=0.5, draw_y=0.5)
y = zoom(_batch_ex(4), filt=0)
fig,axs = plt.subplots(1,4, figsize=(12,3))
fig.suptitle('Center zoom with different scales')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'scale {scales[i]}')
# Fixed scale, random centers (draw_x/draw_y left to their defaults).
zoom = Zoom(draw=1.5, p=1.)
y = zoom(_batch_ex(4), filt=0)
fig,axs = plt.subplots(1,4, figsize=(12,3))
fig.suptitle('Constant scale and different random centers')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax)
`draw_x` and `draw_y` can be specified if you want to customize the magnitudes that are picked when the transform is applied (default is a random float between `-magnitude` and `magnitude`). Each can be a float, a list of floats (which then should have a length equal to the size of the batch) or a callable that returns a float.
# Vertical warping: vary draw_y per batch element, keep draw_x at 0.
scales = [-0.4, -0.2, 0., 0.2, 0.4]
warp = Warp(p=1., draw_y=scales, draw_x=0.)
y = warp(_batch_ex(5), filt=0)
fig,axs = plt.subplots(1,5, figsize=(15,3))
fig.suptitle('Vertical warping')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'magnitude {scales[i]}')
# Horizontal warping: vary draw_x per batch element, keep draw_y at 0.
scales = [-0.4, -0.2, 0., 0.2, 0.4]
warp = Warp(p=1., draw_x=scales, draw_y=0.)
y = warp(_batch_ex(5), filt=0)
fig,axs = plt.subplots(1,5, figsize=(15,3))
fig.suptitle('Horizontal warping')
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'magnitude {scales[i]}')
`draw` can be specified if you want to customize the magnitude that is picked when the transform is applied (default is a random float between `0.5*(1-max_lighting)` and `0.5*(1+max_lighting)`). Each can be a float, a list of floats (which then should have a length equal to the size of the batch) or a callable that returns a float.
# Brightness with a different fixed magnitude per batch element.
scales = [0.1, 0.3, 0.5, 0.7, 0.9]
bright = Brightness(p=1., draw=scales)
y = bright(_batch_ex(5), filt=0)
fig,axs = plt.subplots(1,5, figsize=(15,3))
for i,ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax, title=f'scale {scales[i]}')
`draw` can be specified if you want to customize the magnitude that is picked when the transform is applied (default is a random float taken with the log-uniform distribution between `(1-max_lighting)` and `1/(1-max_lighting)`). Each can be a float, a list of floats (which then should have a length equal to the size of the batch) or a callable that returns a float.
# Contrast with a different fixed magnitude per batch element.
scales = [0.65, 0.8, 1., 1.25, 1.55]
cont = Contrast(p=1., draw=scales)
y = cont(_batch_ex(5), filt=0)
fig,axs = plt.subplots(1,5, figsize=(15,3))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax, title=f'scale {scales[i]}')
# Affine only: two affine transforms are fused into a single one whose
# affine matrix is the product of the individual matrices.
tfms = [Rotate(draw=10., p=1), Zoom(draw=1.1, draw_x=0.5, draw_y=0.5, p=1.)]
comp = setup_aug_tfms([Rotate(draw=10., p=1), Zoom(draw=1.1, draw_x=0.5, draw_y=0.5, p=1.)])
test_eq(len(comp), 1)
x = torch.randn(4,3,5,5)
test_close(comp[0]._get_affine_mat(x),tfms[0]._get_affine_mat(x) @ tfms[1]._get_affine_mat(x))
#We can't test that the output of comp or the composition of tfms on x is the same cause it's not (1 interpolation vs 2)
#Affine + lighting: affine/coord transforms fuse into comp[0], lighting into comp[1].
tfms = [Rotate(), Zoom(), Warp(), Brightness(), Flip(), Contrast()]
comp = setup_aug_tfms(tfms)
test_eq(len(comp), 2)
test_eq(len(comp[0].aff_fs), 3)
test_eq(len(comp[0].coord_fs), 1)
test_eq(len(comp[1].fs), 2)
#Affine + lighting + others: transforms that can't be fused (e.g. Cuda) stay separate.
tfms = [Rotate(), Zoom(), Warp(), Brightness(), Flip(), Contrast(), Cuda()]
comp = setup_aug_tfms(tfms)
test_eq(len(comp), 3)
test_eq(len(comp[0].aff_fs), 3)
test_eq(len(comp[0].coord_fs), 1)
test_eq(len(comp[1].fs), 2)
Random flip (or dihedral if `flip_vert=True`) with `p=0.5` is added when `do_flip=True`. With `p_affine` we apply a random rotation of `max_rotate` degrees, a random zoom of `max_zoom` and a perspective warping of `max_warp`. With `p_lighting` we apply a change in brightness and contrast of `max_lighting`. Custom `xtra_tfms` can be added. `size`, `mode` and `pad_mode` will be used for the interpolation.
# Apply the full default augmentation pipeline to a batch of 9 copies.
tfms = aug_transforms(pad_mode='zeros', max_lighting=0.5, max_warp=0.4)
y = _batch_ex(9)
for t in tfms: y = t(y, filt=0)
_,axs = plt.subplots(3,3, figsize=(12,9))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax)
# Segmentation example: the same CamVid image repeated 10 times, with its
# label mask, run through the default augmentation pipeline on the GPU.
camvid = untar_data(URLs.CAMVID_TINY)
fns = get_image_files(camvid)
cam_fn = fns[0]
# CamVid label masks live in 'labels' and are named '<stem>_P<suffix>'.
mask_fn = camvid/'labels'/f'{cam_fn.stem}_P{cam_fn.suffix}'
def _cam_lbl(fn): return mask_fn
cam_dsrc = DataSource([cam_fn]*10, [PILImage.create, [_cam_lbl, PILMask.create]])
cam_tdl = TfmdDL(cam_dsrc.train, after_item=ToTensor(), after_batch=[Cuda(), ByteToFloatTensor(), *aug_transforms()], bs=9)
_,axs = plt.subplots(3,3, figsize=(9,9))
cam_tdl.show_batch(ctxs=axs.flatten(), vmin=1, vmax=30)
# Points example: one MNIST image repeated, labelled with 5 fixed points
# (the four corners of the 28x35 resized image plus one interior point).
mnist = untar_data(URLs.MNIST_TINY)
fns = get_image_files(mnist)
mnist_fn = fns[0]
pnts = np.array([[0,0], [0,35], [28,0], [28,35], [9, 17]])
# def _pnt_open(fn): return PILImage.create(fn).resize((28,35))
# Dummy labeller: return the fixed points for any file.
# Fix: the original `-> None` annotation was wrong — this returns a `TensorPoint`.
def _pnt_lbl(fn): return TensorPoint.create(pnts)
# Pipeline: resize to 35x28, scale points with the image, then augment
# on the GPU without warping (max_warp=0).
pnt_dsrc = DataSource([mnist_fn]*10, [[PILImage.create, Resize((35,28))], _pnt_lbl])
pnt_tdl = TfmdDL(pnt_dsrc.train, after_item=[PointScaler(), Resize(28), ToTensor()],
                 after_batch=[Cuda(), ByteToFloatTensor(), *aug_transforms(max_warp=0)], bs=9)
_,axs = plt.subplots(3,3, figsize=(9,9))
pnt_tdl.show_batch(ctxs=axs.flatten())
# Bounding-box example: pick one COCO image and its labelled boxes.
coco = untar_data(URLs.COCO_TINY)
images, lbl_bbox = get_annotations(coco/'train.json')
idx=2
coco_fn,bbox = coco/'train'/images[idx],lbl_bbox[idx]
# Dummy labeller: return the fixed bounding boxes for any file.
# Fix: the original `-> None` annotation was wrong — this returns a `BBox`.
def _coco_lbl(fn): return BBox(bbox)
# Pipeline: scale boxes with the image, then augment the batch on the GPU.
coco_dsrc = DataSource([coco_fn]*10, [PILImage.create, [_coco_lbl, BBoxCategorize()]])
coco_tdl = TfmdDL(coco_dsrc.train, after_item=[BBoxScaler(), ToTensor()],
                  after_batch=[Cuda(), ByteToFloatTensor(), *aug_transforms()], bs=9)
_,axs = plt.subplots(3,3, figsize=(9,9))
coco_tdl.show_batch(ctxs=axs.flatten())